{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Calculate ROC-AUC, AUC-CI and Precision \n",
    "\n",
    "* Calculate the area under the receiver operating characteristic curve (ROC-AUC) using the network proximity measures ($d_c$ and $Z_{d_c}$)\n",
    "* Calculate 95% confidence intervals using the bootstrap technique with 2,000 resamplings with sample sizes of 150 each\n",
    "* Calculate precision of the top 10, 25 and 50 predictions, considering only the polyphenol-disease associations with relative distance $Z_{d_c} < -0.5$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from progressbar import ProgressBar\n",
    "from multiprocessing import Pool\n",
    "import scipy\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('../')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import metrics\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.model_selection import ShuffleSplit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Project root is the parent of the working directory; every folder path keeps a trailing '/'\n",
    "basedir = os.path.abspath(os.pardir) + '/'\n",
    "infolder = os.path.join(basedir, 'data/')\n",
    "outfolder = os.path.join(basedir, 'output/')\n",
    "dbs = os.path.join(infolder, 'databases/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import utils.network_utils as network_utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "# Silence every warning category for the rest of the session\n",
    "warnings.simplefilter('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input tables: the proximity calculations and the CTD associations\n",
    "proximity_file = os.path.join(infolder, 'SupplementaryData2.csv')\n",
    "ctd_file = os.path.join(infolder, 'ctd_polyphenols_implicit_explicit.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "## proximity measures for which an AUC is calculated\n",
    "measures = ['closest', 'z_closest']\n",
    "\n",
    "## measures for which a lower value represents a more significant association\n",
    "negative_measures = ['closest', 'z_closest']\n",
    "\n",
    "## key columns identifying a chemical-disease pair (used as merge keys)\n",
    "target_columns = ['chemical', 'disease']\n",
    "\n",
    "## restrict the analysis to these chemicals / diseases (None = no restriction)\n",
    "selected_chemicals = ['gallic acid']\n",
    "selected_diseases = None\n",
    "\n",
    "## name of the column flagging explicit evidence\n",
    "explicit = 'DirectEvidence'\n",
    "\n",
    "## name of the column flagging implicit evidence\n",
    "implicit = 'therapeutic'\n",
    "\n",
    "## CTD associations; the first CSV column is used as the index\n",
    "ctd = pd.read_csv(ctd_file, index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(19734, 8)\n",
      "(299, 6)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>chemical</th>\n",
       "      <th>disease</th>\n",
       "      <th>closest</th>\n",
       "      <th>z_closest</th>\n",
       "      <th>DirectEvidence</th>\n",
       "      <th>therapeutic</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>598</th>\n",
       "      <td>gallic acid</td>\n",
       "      <td>liver diseases</td>\n",
       "      <td>1.0</td>\n",
       "      <td>-1.720618</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>599</th>\n",
       "      <td>gallic acid</td>\n",
       "      <td>lung diseases</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.661377</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>600</th>\n",
       "      <td>gallic acid</td>\n",
       "      <td>overweight</td>\n",
       "      <td>2.0</td>\n",
       "      <td>-0.030014</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>601</th>\n",
       "      <td>gallic acid</td>\n",
       "      <td>bone marrow diseases</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.629800</td>\n",
       "      <td>NaN</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>602</th>\n",
       "      <td>gallic acid</td>\n",
       "      <td>tauopathies</td>\n",
       "      <td>2.0</td>\n",
       "      <td>0.416784</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        chemical               disease  closest  z_closest DirectEvidence  \\\n",
       "598  gallic acid        liver diseases      1.0  -1.720618            NaN   \n",
       "599  gallic acid         lung diseases      2.0   0.661377            NaN   \n",
       "600  gallic acid            overweight      2.0  -0.030014            NaN   \n",
       "601  gallic acid  bone marrow diseases      2.0   0.629800            NaN   \n",
       "602  gallic acid           tauopathies      2.0   0.416784            NaN   \n",
       "\n",
       "     therapeutic  \n",
       "598          1.0  \n",
       "599          NaN  \n",
       "600          NaN  \n",
       "601          1.0  \n",
       "602          NaN  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## load the proximity calculations and lower-case the chemical names\n",
    "dt = pd.read_csv(proximity_file)\n",
    "print (dt.shape)\n",
    "dt['chemical'] = dt.chemical.map(lambda name: name.lower())\n",
    "\n",
    "## combine the proximity measures with the CTD associations\n",
    "dx = dt[target_columns + measures].merge(ctd, on=target_columns, how='outer')\n",
    "\n",
    "## one row mask: a non-missing first measure, the requested diseases / chemicals,\n",
    "## and (only when a selection is active and the column exists) mapped chemicals\n",
    "row_mask = dx[measures[0]].notnull()\n",
    "if selected_diseases:\n",
    "    row_mask &= dx.disease.isin(selected_diseases)\n",
    "if selected_chemicals:\n",
    "    row_mask &= dx.chemical.isin(selected_chemicals)\n",
    "if (selected_diseases or selected_chemicals) and 'n_mapped_chemical' in dx.columns:\n",
    "    row_mask &= dx.n_mapped_chemical.notnull()\n",
    "\n",
    "dx = dx[row_mask]\n",
    "print (dx.shape)\n",
    "dx.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Explicit\n",
      "5 known\n",
      "294 unknown\n",
      "Implicit + Explicit\n",
      "42 known\n",
      "257 unknown\n"
     ]
    }
   ],
   "source": [
    "## binarise the evidence flags: 1 = known association, 0 = unknown\n",
    "dx[implicit] = dx[implicit].fillna(0)\n",
    "\n",
    "## any non-missing entry counts as explicit evidence; the `!= 0` test keeps the\n",
    "## cell idempotent, so re-running it on already binarised values changes nothing\n",
    "## NOTE(review): assumes the raw column never holds a literal 0 - confirm against the CTD file\n",
    "dx[explicit] = (dx[explicit].notnull() & (dx[explicit] != 0)).astype(int)\n",
    "\n",
    "for title, flag in [('Explicit', explicit), ('Implicit + Explicit', implicit)]:\n",
    "    print (title)\n",
    "    print (int((dx[flag] == 1).sum()), 'known')\n",
    "    print (int((dx[flag] == 0).sum()), 'unknown')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _roc_auc(frame, label, col):\n",
    "    '''ROC-AUC of the scores in `col` against the 0/1 flags in `label`.\n",
    "\n",
    "    The flags are inverted (1 - label) so that lower scores count as\n",
    "    predictions of a known association.\n",
    "    '''\n",
    "    fpr, tpr, _ = metrics.roc_curve(1 - frame[label], frame[col])\n",
    "    return metrics.auc(fpr, tpr)\n",
    "\n",
    "\n",
    "def calculcate_performance(chemical, measures=measures, label = 'therapeutic',\n",
    "                          precision = True, n_bootstraps = 2000, sample_size = 150,\n",
    "                          seed = 42, top_ns = (10, 25, 50), z_threshold = -0.5):\n",
    "    '''Evaluate each proximity measure for one chemical.\n",
    "\n",
    "    Reads the rows of the global `dx` frame that belong to `chemical` and, for\n",
    "    every column in `measures`, computes the ROC-AUC against `label` together\n",
    "    with a 95% bootstrap confidence interval. For measures without 'z' in\n",
    "    their name it also reports the precision among the top-ranked rows whose\n",
    "    'z_<measure>' value lies below `z_threshold`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    chemical : value matched against dx.chemical.\n",
    "    measures : iterable of score column names (lower = stronger prediction).\n",
    "    label : column holding 1 for known and 0 for unknown associations.\n",
    "    precision : when False, the top-N precision columns are skipped.\n",
    "    n_bootstraps : number of bootstrap resamples.\n",
    "    sample_size : rows drawn with replacement per resample.\n",
    "    seed : seed of the RandomState used for resampling (reset for every measure).\n",
    "    top_ns : cut-offs N for the prec_relative_top<N> columns.\n",
    "    z_threshold : only rows with a z-score below this value enter the ranking.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame with one row per measure and the columns value, measure,\n",
    "    ci_upper, ci_lower, error_l, error_u, the optional prec_relative_top<N>\n",
    "    columns, and chemical. AUC and interval are NaN when the chemical does not\n",
    "    have both known and unknown associations.\n",
    "    '''\n",
    "    dw = dx[dx.chemical == chemical]\n",
    "\n",
    "    res = defaultdict(dict)\n",
    "    for x, col in enumerate(measures):\n",
    "        sub = dw[[label, col]].reset_index(drop=True)\n",
    "        n_rows = len(sub.index)\n",
    "\n",
    "        if sub[label].nunique() < 2:\n",
    "            # An AUC needs both classes; without this guard the resampling\n",
    "            # loop below could never draw a valid sample.\n",
    "            roc_auc = s_lower = s_upper = np.nan\n",
    "        else:\n",
    "            roc_auc = _roc_auc(sub, label, col)\n",
    "            ### bootstrap\n",
    "            rng = np.random.RandomState(seed)\n",
    "            bootstraps = []\n",
    "            for _ in range(n_bootstraps):\n",
    "                # sample row positions with replacement; redraw until the\n",
    "                # resample contains both classes\n",
    "                while True:\n",
    "                    # randint excludes its upper bound; it replaces the\n",
    "                    # deprecated random_integers(0, n_rows - 1, ...)\n",
    "                    indices = rng.randint(0, n_rows, sample_size)\n",
    "                    boot = sub.iloc[indices]\n",
    "                    if boot[label].nunique() > 1:\n",
    "                        break\n",
    "                bootstraps.append(_roc_auc(boot, label, col))\n",
    "            s_lower, s_upper = np.percentile(bootstraps, [2.5, 97.5])\n",
    "\n",
    "        res[x]['value'] = roc_auc\n",
    "        res[x]['measure'] = col\n",
    "        res[x]['ci_upper'] = s_upper\n",
    "        res[x]['ci_lower'] = s_lower\n",
    "        res[x]['error_l'] = roc_auc - s_lower\n",
    "        res[x]['error_u'] = s_upper - roc_auc\n",
    "\n",
    "        ## precision - top absolute proximity among rows with z < z_threshold\n",
    "        if precision and 'z' not in col:\n",
    "            z_col = 'z_%s' % col\n",
    "            # NOTE(review): rows tied on `col` are ordered by the default sort\n",
    "            # algorithm, so which tied rows enter the top N is arbitrary.\n",
    "            ranked = dw.loc[dw[z_col] < z_threshold, [label, col]].sort_values(by = col)\n",
    "            for ntop in top_ns:\n",
    "                top = ranked.iloc[:ntop]\n",
    "                tp = int((top[label] == 1).sum())\n",
    "                fp = int((top[label] == 0).sum())\n",
    "                # no eligible rows -> precision is reported as 0\n",
    "                res[x]['prec_relative_top%d' % ntop] = 1. * tp / (tp + fp) if tp + fp else 0\n",
    "\n",
    "    table = pd.DataFrame.from_dict(res, orient='index')\n",
    "    table['chemical'] = chemical\n",
    "    return table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "## evaluate every chemical in parallel; sorting keeps the row order reproducible\n",
    "samples = sorted(dx.chemical.unique())\n",
    "## the context manager releases the worker processes even if map() raises\n",
    "with Pool(8) as p:\n",
    "    res = p.map(calculcate_performance, samples)\n",
    "## ignore_index avoids repeated index labels, since each per-chemical table starts at 0\n",
    "df = pd.concat(res, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "## stack the per-chemical tables; ignore_index avoids repeated index labels\n",
    "df = pd.concat(res, ignore_index=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>value</th>\n",
       "      <th>measure</th>\n",
       "      <th>ci_upper</th>\n",
       "      <th>ci_lower</th>\n",
       "      <th>error_l</th>\n",
       "      <th>error_u</th>\n",
       "      <th>prec_relative_top10</th>\n",
       "      <th>prec_relative_top25</th>\n",
       "      <th>prec_relative_top50</th>\n",
       "      <th>chemical</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.695711</td>\n",
       "      <td>closest</td>\n",
       "      <td>0.809161</td>\n",
       "      <td>0.580788</td>\n",
       "      <td>0.114923</td>\n",
       "      <td>0.113451</td>\n",
       "      <td>0.5</td>\n",
       "      <td>0.4</td>\n",
       "      <td>0.3</td>\n",
       "      <td>gallic acid</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.603344</td>\n",
       "      <td>z_closest</td>\n",
       "      <td>0.749706</td>\n",
       "      <td>0.455282</td>\n",
       "      <td>0.148062</td>\n",
       "      <td>0.146362</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>NaN</td>\n",
       "      <td>gallic acid</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      value    measure  ci_upper  ci_lower   error_l   error_u  \\\n",
       "0  0.695711    closest  0.809161  0.580788  0.114923  0.113451   \n",
       "1  0.603344  z_closest  0.749706  0.455282  0.148062  0.146362   \n",
       "\n",
       "   prec_relative_top10  prec_relative_top25  prec_relative_top50     chemical  \n",
       "0                  0.5                  0.4                  0.3  gallic acid  \n",
       "1                  NaN                  NaN                  NaN  gallic acid  "
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## write the performance table to the working directory, without the index column\n",
    "df.to_csv('Performance.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
